replace tl.libdevice.llrint with tl.extra.cuda.libdevice.rint#1372
bjmsong wants to merge 2 commits into bitsandbytes-foundation:main from
TimDettmers
left a comment
PR Review: #1372 — Replace tl.libdevice.llrint with tl.extra.cuda.libdevice.rint
[bug-fix] Triton API migration: updates tl.libdevice.llrint calls in 3 Triton quantization kernels to use the new tl.extra.cuda.libdevice path. The tl.libdevice namespace was removed in Triton 3.x, so this addresses a real compatibility issue.
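A minimal sketch of how code could stay compatible with both namespace layouts, assuming only the two attribute paths named in this review (tl.extra.cuda.libdevice and the older tl.libdevice). The helper name resolve_libdevice is hypothetical and not part of bitsandbytes or Triton, and whether an alias resolved this way can be used inside a @triton.jit kernel on every Triton version is an assumption that has not been verified here.

```python
from types import SimpleNamespace


def resolve_libdevice(tl):
    """Return the libdevice namespace from whichever attribute path `tl` exposes.

    Hypothetical helper: tries the newer path first, then the legacy one.
    """
    for path in ("extra.cuda.libdevice", "libdevice"):
        obj = tl
        try:
            for name in path.split("."):
                obj = getattr(obj, name)
        except AttributeError:
            continue  # this layout is not present, try the next one
        return obj
    raise AttributeError("no libdevice namespace found on the given module")


# Stand-ins for the two layouts, used only to demonstrate the lookup order.
new_layout = SimpleNamespace(
    extra=SimpleNamespace(cuda=SimpleNamespace(libdevice=SimpleNamespace(llrint="new")))
)
old_layout = SimpleNamespace(libdevice=SimpleNamespace(llrint="old"))
```

With a real installation, this would be called once at import time as resolve_libdevice(triton.language), so the kernels keep calling llrint and the rounding semantics stay unchanged.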
Blocking issues (2):
- Wrong function: rint instead of llrint. The PR replaces tl.libdevice.llrint with tl.extra.cuda.libdevice.rint, but these are semantically different functions. llrint rounds to the nearest integer and returns an integer type (long long); rint rounds to the nearest integer but returns a float. Since the result is stored to an int8 output tensor, the practical difference may be minor (implicit float-to-int cast), but tl.extra.cuda.libdevice.llrint exists and is the correct 1:1 replacement. Using rint is an unnecessary semantic change that could introduce subtle numerical differences in edge cases (e.g., NaN/Inf handling).
- No regression test. This is a bug fix that changes quantization kernel behavior, so there should be a test that exercises the Triton quantization path and verifies correctness. The bitsandbytes/triton/ kernels are guarded behind is_triton_available(), so existing tests may not exercise them. A minimal test confirming that quantize_rowwise, quantize_global, and quantize_columnwise_and_transpose produce correct output with the updated API would be valuable.
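One way the requested regression test could be grounded is a plain reference implementation of the row-wise formula shown in the diff (127.0 * (x / max_val), rounded to the nearest integer), against which kernel output is compared. The sketch below is pure Python; the name reference_quantize_rowwise is hypothetical, it assumes every row has a nonzero absolute maximum, and it computes in float64, so a real test against the float32 kernel may need an off-by-one tolerance near ties.

```python
def reference_quantize_rowwise(rows):
    """Scale each row by its absolute maximum and round into the int8 range.

    Assumes each row contains at least one nonzero value.
    """
    quantized, row_maxima = [], []
    for row in rows:
        max_val = max(abs(v) for v in row)
        # Python's round() resolves ties to even, matching the default
        # rounding mode used by rint/llrint.
        quantized.append([round(127.0 * (v / max_val)) for v in row])
        row_maxima.append(max_val)
    return quantized, row_maxima


# Example input with one exact tie: 127 * 0.5 = 63.5 rounds to 64.
example_q, example_max = reference_quantize_rowwise([[1.0, -2.0, 0.5]])
```

A Triton-backed test would run quantize_rowwise on the same values and compare element by element, skipping when is_triton_available() is false.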
Additional note: This PR is from September 2024 and is significantly behind main. A rebase will likely be needed. The Triton directory has had substantial changes since then (XPU triton optimizers, compatibility guards, etc.).
- Security: Clear (trivial API path change, no new imports, no suspicious patterns)
- Downstream impact: None (internal Triton kernels, not part of public API)
- Tests: Missing — no test covers the Triton quantization path change
- CI: Not triggered (fork PR — maintainer must approve workflow run)
- Serialization: Not affected
- Cross-PR conflicts: None detected
  abs_x = tl.abs(x)
  max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
- output = tl.libdevice.llrint(127.0 * (x / max_val))
+ output = tl.extra.cuda.libdevice.rint(127.0 * (x / max_val))
tl.extra.cuda.libdevice.rint is not the correct 1:1 replacement for tl.libdevice.llrint. llrint rounds to nearest integer and returns an integer type; rint rounds to nearest integer but returns a float. Since tl.extra.cuda.libdevice.llrint exists in modern Triton, this should use tl.extra.cuda.libdevice.llrint instead to preserve the original semantics.
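The return-type difference described here can be illustrated with a small pure-Python emulation. This is a sketch of the C-level semantics only, not Triton's or libdevice's implementation: the function names mirror the C names for readability, raising on NaN/Inf is a choice made for this sketch (C leaves the llrint result unspecified there), and integer overflow is not modeled because Python integers are unbounded. Under the default round-to-nearest mode both functions resolve ties to even, so the two differ in result type and in non-finite handling rather than in which integer they pick.

```python
import math


def rint(x: float) -> float:
    """Emulate C rint: round to nearest, ties to even, result stays a float."""
    if math.isnan(x) or math.isinf(x):
        return x  # NaN and infinities pass through unchanged
    # copysign keeps the sign of zero, e.g. rint(-0.4) gives -0.0
    return math.copysign(float(round(x)), x)


def llrint(x: float) -> int:
    """Emulate C llrint: same rounding as rint, but the result is an integer."""
    if math.isnan(x) or math.isinf(x):
        # Sketch-only behavior: C leaves this result unspecified.
        raise ValueError("llrint has no well-defined result for NaN/Inf")
    return round(x)
```

For finite inputs the two agree numerically, for example rint(63.5) is 64.0 and llrint(63.5) is 64; only the type of the result differs.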
  x = tl.load(x_ptr + offsets, mask=mask)
  absmax_inv = tl.load(absmax_inv_ptr)
- output = tl.libdevice.llrint(127.0 * (x * absmax_inv))
+ output = tl.extra.cuda.libdevice.rint(127.0 * (x * absmax_inv))
Same issue here: should be tl.extra.cuda.libdevice.llrint (not rint) to match the original semantics.
Closing in favor of #1871, which removes this functionality instead.
No description provided.